import torch
import torch.nn as nn

# Two stacked GRU layers mapping 5 input features to a 6-unit hidden state.
# batch_first is left at its default (False), so every tensor below is laid
# out sequence-first: (seq_len, batch, features).
gru = nn.GRU(input_size=5, hidden_size=6, num_layers=2)

# A single time step for a batch of 3 sequences, 5 features per step.
input_ = torch.randn(1, 3, 5)
# Initial hidden state, one slice per layer: (num_layers, batch, hidden_size).
h0 = torch.randn(2, 3, 6)

# output: top-layer hidden state at every time step -> (1, 3, 6)
# hn:     final hidden state of every layer          -> (2, 3, 6)
output, hn = gru(input_, h0)

for tensor in (output, hn):
    print(tensor)
    print(tensor.shape)
